"""
Filename: welfare_cost_lumpsum.jl
Paper: Unemployment and the Distribution of Liquidity
Authors: Zach Bethune and Guillaume Rocheteau
Contact: bethune@rice.edu
Last Modified: 8/22/2023
Purpose: functions used to compute welfare cost of a change in money growth rate.
"""
###############################################################################
# Functions used in welfare calculation
###############################################################################
#Solve steady state value functions
"""
    W_solve_ss(Wprime; mm::model_outcomes, parms::parmsmodel, tol=1e-7)

Solve the steady-state value function by iterating `W_iter` to a fixed point.

Starting from the guess `Wprime`, `W_iter` is applied repeatedly, holding fixed the
objects read from `mm` (`py`, `tau[1,1,1]`, `wages_bar`, `theta`, `am_p`, `a_p`,
`ystar`, `c`, `agrid`). Iteration stops once the largest absolute change between
successive iterates is no greater than `tol`, or after 100000 iterations.

# Arguments
- `Wprime`: initial guess; must be elementwise comparable with an `(Na, 2, Nz)` array.

# Keywords
- `mm::model_outcomes`: container holding the prices, policies and grid listed above.
- `parms::parmsmodel`: model parameters, forwarded unchanged to `W_iter`.
- `tol`: stopping threshold on the sup-norm change between iterates.

# Returns
The last iterate, an `(Na, 2, Nz)` array (`Na` and `Nz` are globals defined outside
this file section). If `tol >= 1` no iteration is performed and zeros are returned.
A warning is logged when the loop ends without meeting `tol` (iteration cap reached
or the distance became `NaN`); the last iterate is still returned in that case.
"""
function W_solve_ss(Wprime;mm::model_outcomes,parms::parmsmodel,tol=1e-7)
    iter_max = 100000
    err = 1.0   #called `err` so that Base.error is not shadowed inside the function
    iter = 0
    W = zeros((Na,2,Nz))
    while err > tol && iter < iter_max
        iter += 1
        W[:,:,:] = W_iter(Wprime; py_prime=copy(mm.py),tau=copy(mm.tau[1,1,1]),wages_bar=mm.wages_bar,theta_prime=mm.theta,am_p=mm.am_p,a_p=mm.a_p,ystar=mm.ystar,c=mm.c,agrid=mm.agrid,parms)
        err = maximum(abs.(W .- Wprime))
        #rebinds the local name only; the caller's array is left untouched
        Wprime = copy(W)
    end

    #`!(err <= tol)` is also true for NaN, which would otherwise end the loop unnoticed
    if !(err <= tol)
        @warn "W_solve_ss: value function iteration stopped without converging" iterations=iter sup_norm_change=err tol
    end
    return W
end

#Solve steady state value functions under distorted consumption decisions by Delt
"""
    W_solve_Delt(Wprime, Delt; mm::model_outcomes, parms::parmsmodel, tol=1e-10)

Solve the steady-state value function when consumption is scaled by `Delt`, by
iterating `W_iter_Delt` to a fixed point.

Starting from the guess `Wprime`, `W_iter_Delt` is applied repeatedly, holding fixed
the objects read from `mm` (`py`, `tau[1,1,1]`, `wages_bar`, `theta`, `am_p`, `a_p`,
`ystar`, `c`, `agrid`). Iteration stops once the largest absolute change between
successive iterates is no greater than `tol`, or after 100000 iterations.

# Arguments
- `Wprime`: initial guess; must be elementwise comparable with an `(Na, 2, Nz)` array.
- `Delt`: scaling factor forwarded to `W_iter_Delt`.

# Keywords
- `mm::model_outcomes`: container holding the prices, policies and grid listed above.
- `parms::parmsmodel`: model parameters, forwarded unchanged to `W_iter_Delt`.
- `tol`: stopping threshold on the sup-norm change between iterates.

# Returns
The last iterate, an `(Na, 2, Nz)` array (`Na` and `Nz` are globals defined outside
this file section). If `tol >= 1` no iteration is performed and zeros are returned.
A warning is logged when the loop ends without meeting `tol` (iteration cap reached
or the distance became `NaN`); the last iterate is still returned in that case.
"""
function W_solve_Delt(Wprime,Delt;mm::model_outcomes,parms::parmsmodel,tol=1e-10)
    iter_max = 100000
    err = 1.0   #called `err` so that Base.error is not shadowed inside the function
    iter = 0
    W = zeros((Na,2,Nz))
    while err > tol && iter < iter_max
        iter += 1
        W[:,:,:] = W_iter_Delt(Wprime; py_prime=copy(mm.py),tau=copy(mm.tau[1,1,1]),wages_bar=mm.wages_bar,theta_prime=mm.theta,am_p=mm.am_p,a_p=mm.a_p,ystar=mm.ystar,c=mm.c,agrid=mm.agrid,parms,Delt)
        err = maximum(abs.(W .- Wprime))
        #rebinds the local name only; the caller's array is left untouched
        Wprime = copy(W)
    end

    #`!(err <= tol)` is also true for NaN, which would otherwise end the loop unnoticed
    if !(err <= tol)
        @warn "W_solve_Delt: value function iteration stopped without converging" Delt iterations=iter sup_norm_change=err tol
    end
    return W
end

#Function to iterate value function once
"""
    W_iter(Wprime; py_prime, tau, wages_bar, theta_prime, am_p, a_p, ystar, c, agrid,
           parms::parmsmodel)

Apply one step of the value-function recursion: given next period's value `Wprime`,
return this period's value `W`, an `(Na, 2, Nz)` array.

For asset index `i`, current employment index `j` and permanent-type index `k`,

    W[i,j,k] = u(c[i,j,k]) + BETA * sum over l of PI[j,l] * (
                   (1-ALPHA)         *  Wp(a')
                 + ALPHA*(1-ALPHAmf) * (υ(y_m)  + Wp(a' - y_m*py_prime))
                 + ALPHA*ALPHAmf     * (υ(y_mf) + Wp(a' - y_mf*py_prime)) )

where `a' = a_p[i,j,k]`, `Wp` is `Wprime[:,l,k]` linearly interpolated over `agrid`
(extrapolated linearly outside the grid), `y_m` is the result of `ym` and `y_mf` is
the first element of the result of `ymf`.

# Arguments
- `Wprime`: next-period value, indexed `[asset, employment, z]` on `agrid`.

# Keywords
- `py_prime`: price multiplying `y_m` / `y_mf` in the asset deduction.
- `tau`, `wages_bar`: accepted so existing callers keep working; not used here.
- `theta_prime`: passed to `λ` to obtain the second row of the transition matrix.
- `am_p`, `a_p`: policy arrays indexed `[i,j,k]`; both are passed to `ym` and `ymf`,
  and `a_p` is also where the continuation value is evaluated.
- `ystar`: array indexed like `Wprime`, interpolated over `agrid` and passed to
  `ym` and `ymf`.
- `c`: array indexed `[i,j,k]` entering through `u`.
- `agrid`: knots of the asset grid.
- `parms::parmsmodel`: supplies `DELTA`, `BETA`, `ALPHA`, `ALPHAmf` and is forwarded
  to `λ`, `υ` and `u`.

The globals `Na`, `Nz` and the helpers `λ`, `υ`, `u`, `ym`, `ymf` are defined outside
this file section.
"""
function W_iter(Wprime; py_prime,tau,wages_bar,theta_prime,am_p,a_p,ystar,c,agrid,parms::parmsmodel)
    #employment transition matrix: row = status this period, column = status next period
    #(presumably 1 = employed, 2 = unemployed, given DELTA in row 1 and λ in row 2)
    lam = λ(theta_prime;parms)
    PI = [(1.0-parms.DELTA) parms.DELTA; lam 1.0-lam]

    #interpolants of Wprime and ystar over assets, one per (next-period employment l, type k).
    #They do not depend on (i,j), so they are built once here rather than inside the loops.
    Wprime_itp = [LinearInterpolation(agrid,Wprime[:,l,k],extrapolation_bc=Line()) for l=1:2, k=1:Nz]
    ystar_itp = [LinearInterpolation(agrid,ystar[:,l,k],extrapolation_bc=Line()) for l=1:2, k=1:Nz]

    #update W
    W = zeros((Na,2,Nz))
    for j=1:2 #employment this period
        for k=1:Nz #permanent z
            for i=1:Na #beginning of period wealth
                a_next = a_p[i,j,k]
                am_next = am_p[i,j,k]
                for l=1:2 #employment next period
                    Wprime_int = Wprime_itp[l,k]
                    ystar_int = ystar_itp[l,k]
                    disc = parms.BETA*PI[j,l]

                    #branch weighted (1-ALPHA): continuation value at a' only
                    W[i,j,k] += disc*(1.0-parms.ALPHA)*Wprime_int(a_next)

                    #branch weighted ALPHA*(1-ALPHAmf): quantity from ym
                    y_m = ym(ystar_int,a_next,am_next,py_prime)
                    W[i,j,k] += disc*parms.ALPHA*(1-parms.ALPHAmf)*(υ(y_m;parms)+Wprime_int(a_next-y_m*py_prime))

                    #branch weighted ALPHA*ALPHAmf: quantity from ymf (first returned element)
                    y_mf = ymf(ystar_int,a_next,am_next,py_prime)[1]
                    W[i,j,k] += disc*parms.ALPHA*parms.ALPHAmf*(υ(y_mf;parms)+Wprime_int(a_next-y_mf*py_prime))
                end
                #current-period utility from c
                W[i,j,k] += u(c[i,j,k];parms)
            end
        end
    end
    return W
end

#Function to iterate value function under distorted consumption decisions by Delt
"""
    W_iter_Delt(Wprime; py_prime, tau, wages_bar, theta_prime, am_p, a_p, ystar, c,
                agrid, parms::parmsmodel, Delt)

Apply one step of the value-function recursion with every utility argument scaled by
`Delt`: given next period's value `Wprime`, return this period's value `W`, an
`(Na, 2, Nz)` array.

For asset index `i`, current employment index `j` and permanent-type index `k`,

    W[i,j,k] = u(c[i,j,k]*Delt) + BETA * sum over l of PI[j,l] * (
                   (1-ALPHA)         *  Wp(a')
                 + ALPHA*(1-ALPHAmf) * (υ(y_m*Delt)  + Wp(a' - y_m*py_prime))
                 + ALPHA*ALPHAmf     * (υ(y_mf*Delt) + Wp(a' - y_mf*py_prime)) )

where `a' = a_p[i,j,k]`, `Wp` is `Wprime[:,l,k]` linearly interpolated over `agrid`
(extrapolated linearly outside the grid), `y_m` is the result of `ym` and `y_mf` is
the first element of the result of `ymf`. `Delt` multiplies only the arguments of
`u` and `υ`; the asset positions at which `Wp` is evaluated use the unscaled
quantities.

# Arguments
- `Wprime`: next-period value, indexed `[asset, employment, z]` on `agrid`.

# Keywords
- `py_prime`: price multiplying `y_m` / `y_mf` in the asset deduction.
- `tau`, `wages_bar`: accepted so existing callers keep working; not used here.
- `theta_prime`: passed to `λ` to obtain the second row of the transition matrix.
- `am_p`, `a_p`: policy arrays indexed `[i,j,k]`; both are passed to `ym` and `ymf`,
  and `a_p` is also where the continuation value is evaluated.
- `ystar`: array indexed like `Wprime`, interpolated over `agrid` and passed to
  `ym` and `ymf`.
- `c`: array indexed `[i,j,k]` entering through `u`.
- `agrid`: knots of the asset grid.
- `parms::parmsmodel`: supplies `DELTA`, `BETA`, `ALPHA`, `ALPHAmf` and is forwarded
  to `λ`, `υ` and `u`.
- `Delt`: multiplicative factor applied inside `u` and `υ`.

The globals `Na`, `Nz` and the helpers `λ`, `υ`, `u`, `ym`, `ymf` are defined outside
this file section.
"""
function W_iter_Delt(Wprime; py_prime,tau,wages_bar,theta_prime,am_p,a_p,ystar,c,agrid,parms::parmsmodel,Delt)
    #employment transition matrix: row = status this period, column = status next period
    #(presumably 1 = employed, 2 = unemployed, given DELTA in row 1 and λ in row 2)
    lam = λ(theta_prime;parms)
    PI = [(1.0-parms.DELTA) parms.DELTA; lam 1.0-lam]

    #interpolants of Wprime and ystar over assets, one per (next-period employment l, type k).
    #They do not depend on (i,j), so they are built once here rather than inside the loops.
    Wprime_itp = [LinearInterpolation(agrid,Wprime[:,l,k],extrapolation_bc=Line()) for l=1:2, k=1:Nz]
    ystar_itp = [LinearInterpolation(agrid,ystar[:,l,k],extrapolation_bc=Line()) for l=1:2, k=1:Nz]

    #update W
    W = zeros((Na,2,Nz))
    for j=1:2 #employment this period
        for k=1:Nz #permanent z
            for i=1:Na #beginning of period wealth
                a_next = a_p[i,j,k]
                am_next = am_p[i,j,k]
                for l=1:2 #employment next period
                    Wprime_int = Wprime_itp[l,k]
                    ystar_int = ystar_itp[l,k]
                    disc = parms.BETA*PI[j,l]

                    #branch weighted (1-ALPHA): continuation value at a' only
                    W[i,j,k] += disc*(1.0-parms.ALPHA)*Wprime_int(a_next)

                    #branch weighted ALPHA*(1-ALPHAmf): quantity from ym, scaled by Delt inside υ only
                    y_m = ym(ystar_int,a_next,am_next,py_prime)
                    W[i,j,k] += disc*parms.ALPHA*(1-parms.ALPHAmf)*(υ(y_m*Delt;parms)+Wprime_int(a_next-y_m*py_prime))

                    #branch weighted ALPHA*ALPHAmf: quantity from ymf (first returned element)
                    y_mf = ymf(ystar_int,a_next,am_next,py_prime)[1]
                    W[i,j,k] += disc*parms.ALPHA*parms.ALPHAmf*(υ(y_mf*Delt;parms)+Wprime_int(a_next-y_mf*py_prime))
                end
                #current-period utility from c scaled by Delt
                W[i,j,k] += u(c[i,j,k]*Delt;parms)
            end
        end
    end
    return W
end

#Compute consumption equivalent welfare
"""
    Welfare_avg_CE(Wnew0, gss, Wbase0_Delt0; mm_base::model_outcomes,
                   parms_base::parmsmodel, bracket=(0.8,1.2), atol=1e-4, tol=1e-7)

Compute the consumption-equivalent welfare measure `Delt` in terms of average welfare.

`Delt` is the root of

    sum(mm_base.g .* W_solve_Delt(Wbase0_Delt0, Delt; ...)) - sum(gss .* Wnew0)

i.e. the scaling of consumption in the base steady state at which `mm_base.g`-weighted
welfare equals `gss`-weighted welfare under the new policy. The root is found with
`find_zero` using `Roots.Brent()`.

# Arguments
- `Wnew0`: value function under the new policy.
- `gss`: weights applied to `Wnew0` (presumably the distribution over states that
  goes with `Wnew0`; verify against the caller).
- `Wbase0_Delt0`: initial guess handed to `W_solve_Delt` at every evaluation.

# Keywords
- `mm_base::model_outcomes`, `parms_base::parmsmodel`: base-economy outcomes and
  parameters, forwarded to `W_solve_Delt`; `mm_base.g` supplies the base weights.
- `bracket`: interval searched for `Delt`. The welfare difference has to change sign
  across it, otherwise the bracketing solver cannot start; widen it for large
  welfare changes.
- `atol`: absolute tolerance passed to `find_zero`.
- `tol`: convergence tolerance passed to `W_solve_Delt`.

The defaults reproduce the values that were previously fixed inside the function.

# Returns
The scalar `Delt` returned by `find_zero`.
"""
function Welfare_avg_CE(Wnew0, gss, Wbase0_Delt0; mm_base::model_outcomes, parms_base::parmsmodel,
                        bracket=(0.8,1.2), atol=1e-4, tol=1e-7)
    #Average welfare under new policy
    Wnew0_avg = sum(gss.*Wnew0)

    #Welfare difference as a function of the consumption scaling Delt
    function Wel_diff(Delt)
        #solve the base value function with consumption scaled by Delt
        W_Delt = W_solve_Delt(Wbase0_Delt0,Delt;mm=mm_base, parms=parms_base,tol=tol)

        #compute average welfare in base steady state given Delt
        Wbase0_avg = sum(mm_base.g.*W_Delt)

        return Wbase0_avg - Wnew0_avg
    end

    Delt = find_zero(Wel_diff,bracket,Roots.Brent(),verbose=false,atol=atol)

    return Delt
end

#Compute consumption equivalent welfare by household state
"""
    Welfare_bystate_CE(Wnew0, Wbase0_Delt0, aind, eind, kind; mm_base::model_outcomes,
                       parms_base::parmsmodel, bracket=(0.8,1.2), atol=1e-4, tol=1e-7)

Compute the consumption-equivalent welfare measure `Delt` for a single household state.

`Delt` is the root of

    W_solve_Delt(Wbase0_Delt0, Delt; ...)[aind, eind, kind] - Wnew0

i.e. the scaling of consumption in the base steady state at which the value at state
`(aind, eind, kind)` equals `Wnew0`. The root is found with `find_zero` using
`Roots.Brent()`.

# Arguments
- `Wnew0`: value to match; it is subtracted from a single element of the base value
  function, so it is expected to be a scalar (presumably the new-policy value at the
  same state; verify against the caller).
- `Wbase0_Delt0`: initial guess handed to `W_solve_Delt` at every evaluation.
- `aind`, `eind`, `kind`: indices into the first, second and third dimension of the
  value function returned by `W_solve_Delt`.

# Keywords
- `mm_base::model_outcomes`, `parms_base::parmsmodel`: base-economy outcomes and
  parameters, forwarded to `W_solve_Delt`.
- `bracket`: interval searched for `Delt`. The welfare difference has to change sign
  across it, otherwise the bracketing solver cannot start; widen it for large
  welfare changes.
- `atol`: absolute tolerance passed to `find_zero`.
- `tol`: convergence tolerance passed to `W_solve_Delt`.

The defaults reproduce the values that were previously fixed inside the function.

# Returns
The scalar `Delt` returned by `find_zero`.
"""
function Welfare_bystate_CE(Wnew0,Wbase0_Delt0,aind,eind,kind; mm_base::model_outcomes, parms_base::parmsmodel,
                            bracket=(0.8,1.2), atol=1e-4, tol=1e-7)
    #Welfare difference at the chosen state as a function of the consumption scaling Delt
    function Wel_diff(Delt)
        #solve the base value function with consumption scaled by Delt
        W_Delt = W_solve_Delt(Wbase0_Delt0,Delt;mm=mm_base, parms=parms_base,tol=tol)

        #welfare in base steady state at the requested state given Delt
        Wbase0 = W_Delt[aind,eind,kind]

        return Wbase0 - Wnew0
    end

    Delt = find_zero(Wel_diff,bracket,Roots.Brent(),verbose=false,atol=atol)

    return Delt
end

#END